import torch
from loss import *
from utils import *
from tqdm import tqdm


def train_func(cfg, dataloader, model, optimizer, criterion, criterion2, lamda=0, epoch=0):
    """Run one training epoch over *dataloader* and return the average losses.

    Args:
        cfg: Configuration object, forwarded unchanged to ``CLAS2``.
        dataloader: Iterable yielding ``(v_input, t_input, label, multi_label, sf)``
            batches. Only ``v_input``, ``label`` and ``sf`` are used.
        model: Module called as ``model(v_input, seq_len)``; only the first of
            its three outputs (the logits) is used.
        optimizer: Optimizer that is zeroed, back-propagated and stepped once
            per batch.
        criterion: Loss criterion forwarded to ``CLAS2``.
        criterion2: Currently unused; kept so existing call sites keep working.
        lamda: Currently unused; kept so existing call sites keep working.
        epoch: Current epoch index, forwarded to ``CLAS2``.

    Returns:
        A tuple ``(mean_t_loss, mean_s_loss)`` of 0-dim tensors.
        ``mean_t_loss`` is the mean of the per-batch ``CLAS2`` losses (detached
        from the autograd graph). No secondary loss is computed at the moment,
        so ``mean_s_loss`` is always ``0.0``. Either value is ``0.0`` when no
        batches were recorded (e.g. an empty dataloader); previously both
        cases raised ``ZeroDivisionError``.
    """
    t_loss = []
    # Secondary-loss history. Nothing appends to it yet (criterion2/lamda are
    # unused), which is why the average below must tolerate an empty list.
    s_loss = []
    with torch.set_grad_enabled(True):
        model.train()
        for v_input, t_input, label, multi_label, sf in tqdm(dataloader):
            # Per-sample count of time steps whose feature vector is not all zeros.
            # NOTE(review): assumes v_input is (batch, time, feature) with
            # trailing zero padding — confirm against the dataset.
            seq_len = torch.sum(torch.max(torch.abs(v_input), dim=2)[0] > 0, 1)
            # Trim padding beyond the longest sequence in this batch.
            v_input = v_input[:, :torch.max(seq_len), :]
            v_input = v_input.float().cuda(non_blocking=True)
            label = label.float().cuda(non_blocking=True)
            # t_input and multi_label are not consumed by the current loss, so
            # they are no longer copied to the GPU.

            logits, _, _ = model(v_input, seq_len)

            loss = CLAS2(logits, label, seq_len, criterion, sf, v_input, epoch, cfg)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Detach so the stored value does not keep this batch's autograd
            # graph (and its activations) alive for the rest of the epoch.
            t_loss.append(loss.detach())

    return _mean_or_zero(t_loss), _mean_or_zero(s_loss)


def _mean_or_zero(values):
    """Return ``sum(values) / len(values)``, or a 0-dim tensor ``0.0`` if *values* is empty."""
    if not values:
        return torch.tensor(0.0)
    return sum(values) / len(values)
